{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3b44703",
   "metadata": {},
   "source": [
    "# U-Net: Training Image Segmentation Models in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416b0a4b",
   "metadata": {},
   "source": [
    "Copyright 2016 The BigDL Authors."
   ]
  },
  {
   "cell_type": "raw",
   "id": "a474a628",
   "metadata": {},
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf304c91",
   "metadata": {},
   "source": [
    "SparkXShards in Orca lets users process large-scale datasets with existing Python code in a distributed, data-parallel fashion. This notebook trains U-Net, an image segmentation model, in PyTorch using the Orca Estimator on a SparkXShards of images.\n",
    "\n",
    "It is adapted from [U-Net: Training Image Segmentation Models in PyTorch](https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/) and uses the [TGS Salt Segmentation Dataset](https://www.kaggle.com/c/tgs-salt-identification-challenge)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7ba935",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install PyTorch and torchvision\n",
    "!pip install torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fca5adf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import necessary libraries\n",
    "from bigdl.orca.data.shard import SparkXShards\n",
    "\n",
    "import bigdl.orca.data\n",
    "from bigdl.orca import init_orca_context, stop_orca_context\n",
    "from bigdl.orca.learn.pytorch import Estimator\n",
    "import numpy as np\n",
    "from torch.nn import ConvTranspose2d\n",
    "from torch.nn import Conv2d\n",
    "from torch.nn import MaxPool2d\n",
    "from torch.nn import Module\n",
    "from torch.nn import ModuleList\n",
    "from torch.nn import ReLU\n",
    "from torchvision.transforms import CenterCrop\n",
    "from torch.nn import functional as F\n",
    "import torch.nn\n",
    "import torch.optim as optim\n",
    "from torchvision import transforms\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "833f9193",
   "metadata": {},
   "source": [
    "Start an OrcaContext in local mode with 4 cores and 8 GB of memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cb05ce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc = init_orca_context(cluster_mode=\"local\", cores=4, memory=\"8g\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a46ad6a7",
   "metadata": {},
   "source": [
    "## Load images into SparkXShards in parallel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "904060e4",
   "metadata": {},
   "source": [
    "Load the data into `data_shards`, a SparkXShards that can be operated on in parallel. Each record is a tuple of (image: PIL.Image.Image, mask: PIL.Image.Image)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7be0d69e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "input_dir = \"../tsg_salt/images/\"\n",
    "target_dir = \"../tsg_salt/masks/\"\n",
    "data_shards = bigdl.orca.data.read_images(input_dir, \n",
    "                                          target_path=target_dir, \n",
    "                                          image_type=\".png\", \n",
    "                                          target_type=\".png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6d12d912",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGUAAABlCAIAAADbHrqYAAAfoUlEQVR4nK1dS3PcxtXtxmuAGcwMSZGmTDmyXFI5FTlRXFlknaxSzn/KJj8ivyS7bLLMJpVNyuVykirHMs3niPMGMHj0tzjswzsNzIh2vl6oSBBodJ++j3Nv34b0H//4xyiKoigKgmC1Wm02G2NMURSbzWY2m9V1naap53mbzcbzvDRNX7169eTJkyzL1ut1HMdJkvR6ve+///6bb77J89z3/V6vh/uLoqiqyvM83/eVUkqpsizLstxsNmmaHh4e/upXv/r444///Oc//+EPf/jtb3/75s2bzWYTRZHneUqppmniOO71emEY1nWd5/nd3V1RFOPx+ODgQCm1Xq+n02mv1zs9PQ3DsCiKMAyrqrq6uirLcjQa9fv9zWZT13VVVWVZ1nUdx3EQBJ7nGWOUUkmSBEEQBEFVVavVqmmauq5932+aBvOtqioIgjAMtdZlWVZVFWBWGJzW2hhTliW601qjd0xAKbXZbC4vL5fL5WKxqKrqo48+Go1GZVnmeW6MITRVVdV1jSt8FuOoqgo35Hl+c3Pz8uXL3/zmN59//vlXX3315s0bDG48HodhiPHVdb1er/M8x4Oj0ShNU611nudZlhljxuPxYDBYrVbGGK11lmUYv+d5YRjiIkbl+77v+1prrTUGprXG3DFZXG+aBk8FQQCYiqJ4uFjXddM0cmJN06ALgsXX93o9AFrXtVIKva9Wq7IssQh43BjDGzAIjAbDwj1FUbx7924ymbx8+fKLL77405/+NJ/PwzBUStV13ev1oihar9fL5XKz2WA50zQFWEVRrNdryP5wOKyqqigKyLUxJo7jpmmw9nhdEAR4qe/7cjBFUeR5junjEVzHrH3fB+JYe6hOYIwxxkACAT9+bZoGsur7fl3XnucFQRBFURiGcRyjoyAI8jxfLpd1XUdRhNVA71wx1dWqqqqqarlcvn379tNPP/3iiy/+8pe/XF5eHhwcQN7TNA2CYDqdrlaruq77/X4cx4PBIIqixWKxXq/LsozjeDgc+r6/XC6VUr1eD9AMBgOsVlEUFHOtdRiGtAwQiM1ms16vMXEIYK/XI6YAzvf9KIqARhzHARbBGAOBB3zQed/3sTJ4EjhqrXu9Xq/XwytXq9V6vcbSQbYxVkCGf9HQM36o67qua2j3ZDL55S9/+fvf//6vf/0rDNByuZxOp4PBIM/zsix7vV6/30/TNIoiWMCmabBscRzDIAK49XqtlIqiqKoqGlAKGoCgBlCNsH5S+jBIyDWVD5h6nEzTNDQZWH9Mj51ihmVZYkxhGG42m+VyiWHVtuEdxjY8KwUN70Jv7969Oz8/H41Gv/vd7waDAUxPWZar1WqxWMBCD4fDNE0BJdQQSw0FhPImSVKWJSSdQNDSw30BFDYIY7/fHw6HkFxKH9GE94OEwpIGXHDeB7DwYppDqnEYhnmew5BlWZZlGZCt65p2DQDhQZpFCjI1Alb/4uJiNpv99Kc/ff78+XfffReGYZIkkO5+v58kSb/fB1ibzQYuOAxDGApjTBRFEJDZbLZer/v9flmW0lpjLpQsjMSzLY5jGC94DzoH+lBOBLQhoI5AJTENzg2CQBGt67ooivl8jpdlWQahhfRuNhsqL50ORimNPYYC+TLGXF5e/vvf/3758uWrV68uLi583//ggw8gGrD6MKBwUlrrOI6pWbBKRVHc3d0tFgsuOSyJ7/u0p2Ak+BlPYRgYEswxRBLoo3NME/NChwGVDtDQcXKqNEO4CKWA46dAQbiki+H7KFzyIvVRa319ff3ll1+mafr8+XNQrcFg4Ps+PBeUAosXRVGv14vjmPOkUt/c3ARBAO9JUwBJpJUAlHgW84K2wvVjqERNtxruebD3wB6/wrwRJvlD0zR5nsNX8n4aBdhLDouChj/R42BMQG2xWPzrX/86PT09Pj4eDodgJ7gHbHOz
2UABwTZICDa23dzcLJfLk5MTWAlOB2MG7QJq6AH2F6OC0Pm+D7GVI5RmBJjGcRzI9XcICNDhM5gkGrQ9CB7Ek/1QwSlcskkiRrN6c3NzcXHx+vXrjz766Ouvv4ZNhHUj54awKEs78VcQNEQa/X7f87w8z/kuqC3kAv0QFEIDEYMTgxCAOWHVATp+wCoG0HY5JS6gwwm4aDTqXC7CRFMlmQSGK2k0VoxrXlXVfD7v9XqvX78+Pz8H/4S+YPSYNpDCr+B9YE+DwWA8HoNtINKAWMF1AgVlAywixYEZYyDUnBd1E8smBSWQgqesSW67f6inhFxrDb9DS+cII+VOKqYSkQetr+d5RVEYY37+85//4x//AGmCFGCpsfjkUAj3IH0AK47jPM+hXGVZgjZiyRn64FkiBf2I41hZtkXuifnSFCpLsMuy3NJHKRQSLLxPAgGzAiqEoRAvxkDOSjgiTHmEiOV5vlqtXr9+fXJycnt7i/ifc8bC4OaqqhDta60Hg8FwOByNRqAajJbpH5UlLlxUSfGx9tLqQSwwNoizNDhN0wQERcoXBdJxdhIXyY+Vtd+cJMeqhHVTNoYHJ8IVuL/FYnF9ff3mzZvDw8Nvv/0WUTQIFBcMKrPZbBAkgZqlaQqFQlhz7/XtMKgNFHASMTAhUJYwDGFboIM0vpiIZAuBEqRBGmMqjnzAoQjSIUrxkX4DsMoe0CdMrMTr7du3t7e3aZqChcVxHEUR3RyQapoGUWEURQgqYbOyLIOdIptlJgfUNwxDDBK4QB7JNGnjZAxAP0MPYIx50EeKErVPAiQFja6ND1JYgKBUQEouO1fW/NHcgjFcXFx88803cRxfXV3hLePxOMsy8o/VagXBZDjZNM1isSiKwvM8qjAj5KIo4Dfkn+guARYEim6U+QUMjHkLz2aHHviX2jbz0so4BkgSPwJtRJxAEyAfIdy0vjArkLWqqqbT6bfffntycqKUury8xFiRKUKH6/UaEx4Oh4iZ8jwHIv1+PwgCKC/wxSrCu0HuIEdhGEKUfN+HDNKGUpPgCtEbgcOwAwcO+TOViH/iX9GLtI4OstJkaBE2smeIOqaHHIPW+urq6uzs7NWrV1999dV0OoXzgiph5RFRgscjhA6CoNfrIVaHj6MTVJZIgoJIk3qfzAqCLMtI1ClKxrJ3xjxA0PO8QE7DMfCdQiexI+H0tvNt/JMSCqi2lVq6dlhAxMye5/3617++vb1VSsHGNTaXgHwhs8OIWgaDARIVwAs0BZILOUJjYAgjiGcxBhoT4IjBaNtwHYmTgOmtXTJFIPgnNMmnpCHjwjooS3fBHmgKsaRwc0VR/OIXv/jyyy9hrWikkyQ5PDxEbqsoCkCJGMUYA+PNkZjtoBp6jduMMVB/mDypENoSLmZimLbh8gTyBW3sALmUVWIvuQmFUToHh6Mqqx3KWkY6FnZbFMVsNnvx4sWzZ8+m0+l0Os3z3PO8oW1IM8Bm0XJLsKg+eBeMNxIMGLa8P4oiqLyMc2n1IEwYMIzGvb2XAEkBIQ2RyoV7giAg31OCi+gu2uGooVSQIAhqm+mHatze3q5WqzRNsyxDFtf3/dFolCQJbBZmC5fXNE2WZYAP6CAlK+eJn5kaMnb3i9lqmR0CUhgYeImyrgNCvWXvpTJKOoZXyuinrZhSuaSQkisTbt6M0XsiF9I0zXQ6vby8BFFQds8FUwWtp1fVYjcLPhSDRGgljWMQBEmSMPTBNgeQ5SzQIVeReUT0A1m+10dOwJkqIXAEUKqko4O8zWGwTuMQEVEZkU1aLpe3t7d1XU8mE4THDGgwQ982CiYwhYeF4pBIA2uYOWyI4BEYRJhIaI90DsxPcO7cYQqUMMb7Z8iXUWokNZXK6AiX0wP7YYAFb41xZ1m2XC7DMFwsFginjTFIH8H9AR1MBoKgrbuE3oE6aev+uEELHCmSoBRM1DgDo9VDhoecNnAm01YoOWFpp7SlKpKgShV2cHQaWTupIHqAZTk6Ooqi
CLuKkCzoC8ZGK4O4DywB6S2QNZkFw/2r1UopBcjwb5IkZGRMDpPr0gUraxPuI1O9vQfRJgRtlSRSbSBoodR2ggiEU/bZ9g+N3ejdbDZxHH/wwQfff/89E2SIGZX1gLRW0ETuaUPcoiiCdGCeSHCiZ6Q3yrLs9/uj0QirBeGF0qFPiLxc3XvUOkXJadqyAUa/tIWStlBSnKccrM12OOnb/WBlmQ207+zs7OrqStI6RNHQNXRSFAVkQVne43keGCxIFq24FGSE6FAxEGBjN8akuHk2g8Rdd6VUB1+VP7R/peViYre2uwOe2BbuzFjIPklQ4M4ZWmmtsU0HlcRAET83Ni+MvXSlFHekkMkAxcdOOFm+suwcthLZRG6bwuRxtNIyarsvoawHM8Zs6SMX2duRk5ABo7y5sZvDjkBJAtEpuZJVShO2WCwODw/jOIZAAQ4YKW33FjlJhNDgsavV6t4wW6OjbR6Y6B8cHKxWq/l8jsQGM1/UEvSpRcTOBQ7UDtcm7TRNlWeDGEezpCZKNeSVNl7yCrSGd5ZluVgswFGXyyWu4B44NZZMYYZwYXB8kDgkArXgE6jZ4kqjpuru7o5GFqIAiZOCRm14wEsKV7tRoNivnDMTXo5KqtaekwMZrsv0i7x5vV4XRdHv91erFXIPYIyg6TJLhd7AJ3AF1SgACDdQsnAz6CtehKoDKbPGpiWYHSJYSqnAsTUyBnLwkoJjtjf3paBpm6JoA2S2s5LOg1wwz/OQ6kNUlKYpe0DJEMfjeR6RkgUfgBuEXtndCrIWxDpwptBieC3IKQ2l3OenQARS0aTncoDgVCm6cvKUc2ny2vLVhpV/9URUpGwe/eDggHu0xNTYLRyQDORnOEKIHsImpv3gEPE44ICIMXPdiPoJiBjRl7PQiIc6NbGNnVQrvV1OIkPINr5OD1JC+Sv0C8uAKQGvn/zkJ9PplMliSdDxA8ubQJEk44fl0lqD0JKLAbuiKLClwt1fUjw0hzBizB3sVELmGC95m2y+qJzqvEHKIG5gnlOLXJh8aVVVSZJ89tlnpELcBIRdJxnkVi5sVpIkw+Gw3+/3+30kUaWbgjI2TYMiRcgaQkswDGWTE6jO5e7RPQfk3NoWp1O4JHyOHMkrjjCyHxlCGJsCQ6PfwCsQjnz22Wd/+9vfKFbSKnu2ZJLGRNuEHfoHjjBYQA0LBvYPpgLXCfVcLpfMFEGvUVPDwdd1HbQh2IOX2jZMqkvdHHz5K/O0antHUnXFAFprFEv+7Gc/Ozg4gHExxqBcGlIJVgHRBreAQ4QYkqajT9j+2paWAGUIFPocDAYQLnIjz+aClM0A3tuv9sx3tU7iroSkSP1Fc9yCNHy7tBsXkc5XSp2dnf3zn/+kqabISDVBkARxgHPknVpUAgBHFKli2xiusNfrpWmKQJ1PkakSPs/zAjmZ9iTbTpAYkZTLi6rFG/YIpnPFs/Ez/d18Pp9MJs+fP//73/+OPL0SIQFXCNMmoNgA1yJjrmwlQL1dKQPzT8iGwyH9r7KcTtuiekx5ywU4RMxBSlt6KemP1DJH6NoXHZ8ob6PUyJkvl8uLi4vT01Nj9y+Yw6MPJbGE+EDFsM/GJSnLcj6fK5tZrqoKJf7MKcGQwcazTo/8C0lNPB446ywtrqNZnJ4najslZHzqMQouuRgz3VK4jDF5nl9fX5+dnaGEFYX1oB3kR0hAI/rhFiRcHkg8tGw2myFPDVKGzC1rAJDMwekPZSu3YAHpmiGPQduRycnvMuq6i9y2Id7TGpEyZrKFsqOsiV0sFp7nnZ6evnv3DjLIOdS20B/qw8DF9/3BYJAkSZIkuDOKIgTYyCNCTuXygKYwjOWWmBKVW/f5CdUqz92FBRHZIzgSZTQSWkf6lGAelGtp7/HIer3OsuzFixc3NzfQMoBLpLh3DWqK+miUubJiNAzD+XzO3Tljt2AoGcbWd7MKiv1ruydgbDTu+nI5+ba/k7g4Tb+Pmjh8Ap3r7X0WGjj8CbXP
L168YO2gBB1SSXJPpgr54mkWz/NQzMMJ1rZMX8oyK8+BPkt9uLfWYL+jPef2zDvd6B5EZGt27xURMomUjC7Lsry9vf38889Bl+AlZVZD2701UHyQMkwPpp2R0GAwSNMUSOV5jv0kKp2sIZe+Ff8y0tziq4+ZvNqrks51SeV3Pe7EntKQwXDc3t4GQTAajVC6xKpfvb25h95gxT1bfqRsKaQxJk3Tk5MTZHtWqxWsm9zyQA+M8H171q4Rmw/dOS9OXg7l8bZ8F3ztJvWdiumEojA9T58+5cEgWSpBzlnb0y980LdFd/g1TdMPP/zwyZMn3PqFiJFzIOWPf7mo0rHc76ftsjW83r6hU8Taxl72RtVz2CytbycR0VrjzMyzZ8/Oz89h/sEYSNOUUrBuKAFjNlEphTAIeo36TZxwQ7kwkv1cMFj3RmRAK3uQE8WevV6vY3+oPehdQDgS55i896aJGluGzKe87eoiYzfNLi4uXr58CWjyPKe1UtZUAy/sg2CqkB1kLOA34jgejUbGmPF4DKbG/DWMAFGr7fE8Shak+GG/ls3hFm1QaLx3wUGwHMvYaSiNLWaSQCsbSxibApzP54iljd3lRrERlKCx58Jg1yHIKCfBtH17sFEplSTJ0dERzuyCXuCYDXf8WVZIWsdVqWX9vYSgLVyqJV9tryeVju/oBJT/Mr1D4CBQ8l3MOqDsXCkFQg8sPFsgx1QMENxsNjBS4BmIrlFpgEPNwAuzI7PHFpy2W5Y8TI7l32w2gbPmDhvihAlBG7hdbT+NcP5laws4ipezLDs+Pp7P5yzpAnukscd8EBhBQHDyEdf7/T46BzRMEAIs0DQtchIkhnidb88zddRH/yCH6G3nox+DYxsveiLyKeoaukWG75NPPvnvf//LlattiQrfXtszmOynEUVqBJppaBgv3AY3gmBLshN6ALwo4EAlWG2kHBC1CL60qH6QXlWuhBMGyVc02weYJEzKunP0/+LFi9FotFwuPVt1gjkww6PtoTUmFMEPpBvF7kZlT1sjLWGM4XGawH5AQUaRSu7Xtpvj0Z3J81eHKO1qTsAgDSKXx/GSXABKUNM04/H4+PgYFfkSBXm0qGkacFrA0e/3sZdBccPZECS8PM+D/oLHYVcJiRAgXhQF7dr9qdH989yDY/viey3Rrlc0tpzAMXkk6OBBRVEcHx9fX1+DZFG45NY/7sQPyL7CXXAZ4DqjKEI1LHf8HXbCZCzqTXBnGIZB03XKeg+IvK2xGzydN+xiD7sCI9pX7vrI22CYrq+v0zRNkkTWjfPwgbaBJNkApi1pFIhuVVXYParrejabZVmGXY9+vw//wCCJ68f4fKc+dnoAZ5IOBXOUrtk+CdPZJLGQ3FWJhDWd3c3NDTYmUNrKPKrMMfj2wF9lvwEj/Sz3KHGiBiwXhYz4akO7wF5ayaqqHqIKme12UNg/586b95AJCZYSCyDdhWP1Me71es1iVBAxbmJD42jOPHsGIsuyRuzgRVGUpimDSviE0H57hjRNCWcN3Gn+9p1S3CMddF6qpWWPeVzeI+2607/kvfCSpf1sEG4j+WKSGtl3EHrWQyAqQBk5wMVXDlarVRRFBwcHKGRVVgaVrQkGWdO2zMBz6gsbcTT0vdaHF/c4UAc1LXYGdhk41VJ/Y88FI3sF9+dtl6tU4tQw83xJkjCmMTblD46KYNv3/ZOTk/F4rJTCtz5o7OgfkKQljX/Ai8u7azJ7MGqLxq7m9NyIXUJcoZf0xOk4JqGSJMHroHRMq1OJ6rqGqkJAQE0BOuxjZb/7gfi81+s9ffoUJ3R5XKmxB19rW3PN+NR14Y34tswuG+Qgsl8r+Yiksm0JclSy7R8xntFodHBwgFqSwH7ihHlUGJ3KnvMDpWDhXGAPhSK0ROH6ZDLBlgrkqBY1/cqKLfZH7iFzxq22/ahsj7FHu5Da9Stf53Tevq22ddPPnz9HGDgYDJALrcRnWXglyzLwdaWUTAEyMQ8cZ7PZ5eXl
dDqliwCsDwlVm+q5lzvHZEhlfIxnpHBxnuynEzv27GhlZ3ZIdlXb+tWPP/74P//5D/SrFvVfykaayn7IDr/W4uQneBav4/MVNzc3ZVkeHR2hE+R/PLujrG3h7v2W+H673jn/Thz3uALH3jlgGVtUQZ8t30Lr0zQNPgf1ySefPH369Pb2FrQbdhoT4wkpbI4oUTXV2PCW27S1/UYYcmFKqfF4TL9v7KYvXA3DzG685K/73Z96n+Vy2i632BYxqiT6h5hgD/Hk5GQymXieh1PJaJALI77Eh34YRdZiW9+zBf04RbJarSaTSRiG2L7EYOAfWKOO3fWt89ttvByw9gDRdpp7/iqvSzHsvI3xDaTp/Pw8jmPmTjFtmgVIirYpk6qqcNIK9guekfYIFQLI6OPTkdBE1oU19rzgQ5SqthVHbRsjotnWoD3Y/dBmtqMiXpSQUSXfvn377NkzWHTsGDE2hl0HdvfiEARg8zwFB+ZV2+PJiDGTJLm7u7u+vub3n3x7Gh6HZ+B2oygK2mBxoJ2i16mPardRo8nEFWny+SdaKG+7xBCNHB2sYjqdJkmilMLhZUwY2mosj6/tR7GSJJG+HlKJvI3U3zRNZ7PZ3d1dv99/8uQJDj+y2kdrjXptd39Ijlt1KcgusPbc4ISB+1vbFEqHw8hpNpth+xpcH7a5thXgED2lFDbQmLOmmQvsGVTGodxku7u7GwwG/X6fDBTiSULn1q/+IJLVebPkwLv8Y/uiI+ZtcmtEKRKsUpIkWZb1+31tT9AiuartMV98+5Dn9iCJaZoiUwgGB+4KCcV5VBD90n6mEd8xjeM4yzLN87Wqy3NxtTtxcf60x6J3Yk026FiAdj/Sx3NPCJ8L6PV6+NYvpAxTTdMUGnd0dHRycoIQiol5BOGwU8Ph0BiDg0T4gun5+TkqOlERdnh4iPwPyq4bnq+Vi+n49bbp2dWMDQ/23yblaBey8qX4QXpJZHWOjo5GoxGiZVYKVlWFU8Z5no/H48PDQ0ynruskSfI8x0lkFIoPBgMIznQ6xSHVIAgmkwl20blrifwiynuC9mI+XjGlXPCH9/I1CQStPkGR27dMt0qTj51qrTW+uQrbbGz9sjEGGrrZbPBt68bmxIGX53mACVKWpim+NN00zdnZmdb666+/vrq6AsVdLBbIIA0Gg9FopDrz93to1P8vjejs0FkteQMzWbDcmDb+VNpPcIAoITxUNjDy7Ck1hgHa1mnCh/LD20AnCIIsy/Ap3uVyiSP4g8FAOXhRuDh0tVdGVEuJ9lixXa6zjdqeRYIlwoYFshQoHedXKBDE4PsLSDcjqwWDBd3EtjaUdz6fY08AkJ2cnEB+EUucn5/PZjPUP00mk6qqAqkO6hFGivfsuo0O7r2uQG3vXb4XMt4GKoTPNYF/ot4GogHbhPK29XptRPlvaD+4CHkEcU+S5ODgYDweA7jT01OciAOUk8kky7J3795hZ8T9PuauQb+Xee2Hrw0TcdxDgNujokvhkTOYdmRvuM9WFIXv+yCZYBsgCqzQJGlARMHyk6OjI3wCKgiCo6OjwWAwnU5935/P51dXV6vVKnCEaw98ctA/gqZ16iNFz5FrZ1RtEwG6hOusLSeUkESy88ZurwFNHtfi5iNuQDhJjxFF0Ycffnh8fIyjH9Pp9EG+nDn8Lyjskrs9JszbPvfQdtmO2mq7yQg5asSnI9CIRW2/AElQEIQTUwxsvV7LIBR49Xq90WiEL7sOh8PT09OmaVw+oVpS5vzJ2LYLxD3utbN1qiT/5GTMta0LM7aiWSa25HsbcRLEISWq64OW2Fjirh12s588eXJ4eIjD98PhcKteTi71LsvyGCB2mcLOK2ztnQQHOBJpgFXbL+d62yXFhLK2p0+VTQ1K+dX2+KASzgSuA+l8xEOTyQRnfIfD4cHBQZIkD/kJLZh9J2vdD8FjqNkuleTa7NowdxaPOwzOhBt7WMNsN9xT2c93aFsyJR83xlT2K/bKHrbCgS/Q
EXwGZCv/9cjJS33kyzrv3MMqHGvF1cJ8ZP+OdVPb/0VII04v1dsVYQ6+WpyccuYiEYdkoR/4TWR+7ssJ2jC1x/fexhm2lfExDkRv82S1vQwEVG17D/7biJ1TR6a0+N5fYz+OoCzibLiTUBpbNwvawfya4n4Hd43QPFE16ECwywm0lVdOaVfztg8c7Opc7t22u2XCjiOREMt+HD/AK0p8R8zYUFQpFYYhaoG01g947bJTtMF71I0jc+5pd9vpB9ReMqG3CZqcNvuXD3riyyGdY5Zwt8fDjCsFFkGlsjt1pnN/iCNrunabO3WH73bs2uPbHmWUffLXxtaUSrXgDZ3asAtE59WsIGlstQ9VUGu99e1iZ7aOku6fbVsF2t3+oAel4HR6A1oiJZSr3fOewWjRlDB/TockugbnH535OPL/QyVlz/h2tV0eRm/zCabDpD5yqE3XJnmnTeCfnJ8dr8q3NKLmOGh3pN4nL+/Vux8EsTPtPdoke6YZkgpIa9tpLtgcqydvdsCliNGUB17ry3qPme1+NDtf/8jm9LnHXDpza3uMPQK7pzmiYEQcYpzvpTmPUQv2RCrOaNReZ7r/QdWVpegEQgmA2m6u8+f9MEkvTDWUtpK3PXy4xnm+/Q7HoO4ZpdMeYwQdN9fuUFsX9CPs6a4B7Lq4q93XA+zCUnUty3sZw350TCv26pxDp1bKvz5e03fZh10mpXOd2Nzvifr2215mm3BK4ZI+q/2yTkL3I4Ri19Dbr5D9t10k94fe23/ndccceQ4KZrvJXnY54/ZfO2fS2dpv2dNV53tl66yIkkmXXSPcZXm91v914wbr3MVz6ExbPR8pXBzTj/CV/0uT+HpdJ3N+nNT/Hw7WaOWYM07lAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=101x101 at 0x7F9269B27290>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGUAAABlCAAAAABxF3ITAAAA5ElEQVR4nO2Zyw7DIAwEl6r//8v00ChpXhJrAlro7hEkj8bGJ1LGIQkATqfrTSjv08kloDKvBjVHoMR7SbmEMVzHohi9uSAsw7rEMJIdM8WUWSihtRR1mYoSGYyqy1SUwGBkXaai8IPRdZmKQg9G2MWUv6awT1nZhZWRdjGFDjd+bRdTRCnUIxN3MYUOM/4KFwKj3jFTTDGFSfny67t0ohS3bACXYpnzD29lvuDDl3Oq/TZOyL9F83r8KGUrmnfFdpiHKHfkJQ3fWN4EGrosAKDDvmSggwuAMXbfFFNMMcUUU67zAeQ5I7HCfT4XAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=101x101 at 0x7F9269CBAFD0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# take a look at the data\n",
    "from PIL import ImageOps\n",
    "image, mask = data_shards.first()\n",
    "display(image)\n",
    "mask = ImageOps.autocontrast(mask)\n",
    "display(mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "62d45e04",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check the number of partitions in data_shards\n",
    "data_shards.num_partitions()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce183207",
   "metadata": {},
   "source": [
    "## Transformation of the images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccff8361",
   "metadata": {},
   "source": [
    "Define a `train_transform` function that uses torchvision transforms to resize each image and mask to 80x80 and convert them to NumPy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c40ab5c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_transform(im):\n",
    "    features = im[0]\n",
    "    features = transforms.Resize((80, 80))(features)\n",
    "    features = transforms.ToTensor()(features)\n",
    "    features = features.numpy()\n",
    "\n",
    "    targets = im[1]\n",
    "    targets = transforms.Resize((80, 80))(targets)\n",
    "    targets = transforms.ToTensor()(targets)\n",
    "    targets = targets.numpy()\n",
    "    return features, targets"
   ]
  },
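  {
   "cell_type": "markdown",
   "id": "e4a1f0b7",
   "metadata": {},
   "source": [
    "To make the effect of `ToTensor` in `train_transform` concrete, here is a minimal NumPy-only sketch of the layout change it applies to an 8-bit image: values are scaled to [0, 1] and the channel axis moves to the front. The helper name `to_chw_float` is hypothetical and is not part of torchvision or Orca.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def to_chw_float(image_hwc):\n",
    "    # Hypothetical helper: mimics the ToTensor layout change for a\n",
    "    # uint8 array of shape (H, W, C): scale to [0, 1], channels first.\n",
    "    scaled = image_hwc.astype(np.float32) / 255.0\n",
    "    return np.transpose(scaled, (2, 0, 1))\n",
    "\n",
    "# A synthetic all-white 80x80 RGB image standing in for a resized input.\n",
    "rgb = np.full((80, 80, 3), 255, dtype=np.uint8)\n",
    "chw = to_chw_float(rgb)\n",
    "print(chw.shape, chw.dtype, chw.max())  # (3, 80, 80) float32 1.0\n",
    "```\n",
    "\n",
    "After the transform, each RGB image becomes a (3, 80, 80) array and each single-channel mask a (1, 80, 80) array."
   ]
  },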
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "abbc976d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "data_shards = data_shards.transform_shard(train_transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aa49894",
   "metadata": {},
   "source": [
    "Stack the elements of each partition into ndarrays of features and labels for efficient training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7de5bfdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "data_shards = data_shards.stack_feature_labels()"
   ]
  },
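  {
   "cell_type": "markdown",
   "id": "c9d27a35",
   "metadata": {},
   "source": [
    "Conceptually, stacking turns the list of (feature, label) records in a partition into two batched arrays. The sketch below illustrates the idea with plain NumPy; the function name `stack_partition` and the `x`/`y` dictionary keys are assumptions made for illustration, not a description of Orca's actual implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def stack_partition(records):\n",
    "    # records: iterable of (feature, label) ndarray pairs from one partition.\n",
    "    # The 'x' and 'y' keys are an illustrative assumption.\n",
    "    features, labels = zip(*records)\n",
    "    return {'x': np.stack(features), 'y': np.stack(labels)}\n",
    "\n",
    "# Five synthetic records shaped like the output of train_transform.\n",
    "partition = [(np.zeros((3, 80, 80), dtype=np.float32),\n",
    "              np.zeros((1, 80, 80), dtype=np.float32)) for _ in range(5)]\n",
    "stacked = stack_partition(partition)\n",
    "print(stacked['x'].shape, stacked['y'].shape)  # (5, 3, 80, 80) (5, 1, 80, 80)\n",
    "```\n",
    "\n",
    "Working on one large array per partition avoids per-record overhead when batches are sliced during training."
   ]
  },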
  {
   "cell_type": "markdown",
   "id": "c38da308",
   "metadata": {},
   "source": [
    "## The model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3806b914",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "Build the U-Net model as usual. The model is adapted from https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "394393c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_size = (80, 80)\n",
    "\n",
    "class Block(Module):\n",
    "    def __init__(self, inChannels, outChannels):\n",
    "        super().__init__()\n",
    "        # store the convolution and RELU layers\n",
    "        self.conv1 = Conv2d(inChannels, outChannels, 3)\n",
    "        self.relu = ReLU()\n",
    "        self.conv2 = Conv2d(outChannels, outChannels, 3)\n",
    "    def forward(self, x):\n",
    "        # apply CONV => RELU => CONV block to the inputs and return it\n",
    "        return self.conv2(self.relu(self.conv1(x)))\n",
    "\n",
    "class Encoder(Module):\n",
    "    def __init__(self, channels=(3, 16, 32, 64)):\n",
    "        super().__init__()\n",
    "        # store the encoder blocks and maxpooling layer\n",
    "        self.encBlocks = ModuleList([Block(channels[i], channels[i + 1])\n",
    "                                     for i in range(len(channels) - 1)])\n",
    "        self.pool = MaxPool2d(2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # initialize an empty list to store the intermediate outputs\n",
    "        blockOutputs = []\n",
    "        # loop through the encoder blocks\n",
    "        for block in self.encBlocks:\n",
    "            # pass the inputs through the current encoder block, store\n",
    "            # the outputs, and then apply maxpooling on the output\n",
    "            x = block(x)\n",
    "            blockOutputs.append(x)\n",
    "            x = self.pool(x)\n",
    "        # return the list containing the intermediate outputs\n",
    "        return blockOutputs\n",
    "\n",
    "\n",
    "class Decoder(Module):\n",
    "    def __init__(self, channels=(64, 32, 16)):\n",
    "        super().__init__()\n",
    "        # initialize the number of channels, upsampler blocks, and decoder blocks\n",
    "        self.channels = channels\n",
    "        self.upconvs = ModuleList([ConvTranspose2d(channels[i], channels[i + 1], 2, 2)\n",
    "                                   for i in range(len(channels) - 1)])\n",
    "        self.dec_blocks = ModuleList([Block(channels[i], channels[i + 1])\n",
    "                                      for i in range(len(channels) - 1)])\n",
    "\n",
    "    def forward(self, x, encFeatures):\n",
    "        # loop through the number of channels\n",
    "        for i in range(len(self.channels) - 1):\n",
    "            # pass the inputs through the upsampler blocks\n",
    "            x = self.upconvs[i](x)\n",
    "            # crop the current features from the encoder blocks,\n",
    "            # concatenate them with the current upsampled features,\n",
    "            # and pass the concatenated output through the current\n",
    "            # decoder block\n",
    "            encFeat = self.crop(encFeatures[i], x)\n",
    "            x = torch.cat([x, encFeat], dim=1)\n",
    "            x = self.dec_blocks[i](x)\n",
    "        # return the final decoder output\n",
    "        return x\n",
    "\n",
    "    def crop(self, encFeatures, x):\n",
    "        # grab the dimensions of the inputs, and crop the encoder\n",
    "        # features to match the dimensions\n",
    "        (_, _, H, W) = x.shape\n",
    "        encFeatures = CenterCrop([H, W])(encFeatures)\n",
    "        # return the cropped features\n",
    "        return encFeatures\n",
    "\n",
    "class UNet(Module):\n",
    "    def __init__(self, encChannels=(3, 16, 32, 64),\n",
    "                 decChannels=(64, 32, 16),\n",
    "                 nbClasses=1, retainDim=True,\n",
    "                 outSize = (img_size[0], img_size[1])):\n",
    "        super().__init__()\n",
    "        # initialize the encoder and decoder\n",
    "        self.encoder = Encoder(encChannels)\n",
    "        self.decoder = Decoder(decChannels)\n",
    "        # initialize the regression head and store the class variables\n",
    "        self.head = Conv2d(decChannels[-1], nbClasses, 1)\n",
    "        self.retainDim = retainDim\n",
    "        self.outSize = outSize\n",
    "\n",
    "    def forward(self, x):\n",
    "        # grab the features from the encoder\n",
    "        encFeatures = self.encoder(x)\n",
    "        # pass the encoder features through decoder making sure that\n",
    "        # their dimensions are suited for concatenation\n",
    "        decFeatures = self.decoder(encFeatures[::-1][0],\n",
    "            encFeatures[::-1][1:])\n",
    "        # pass the decoder features through the regression head to\n",
    "        # obtain the segmentation mask\n",
    "        segMap = self.head(decFeatures)\n",
    "        # check to see if we are retaining the original output\n",
    "        # dimensions and if so, then resize the output to match them\n",
    "        if self.retainDim:\n",
    "            segMap = F.interpolate(segMap, self.outSize)\n",
    "        # return the segmentation map (named segMap to avoid shadowing the builtin map)\n",
    "        return segMap"
   ]
  },
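  {
   "cell_type": "markdown",
   "id": "b7e1c0d2",
   "metadata": {},
   "source": [
    "Because the 3x3 convolutions in `Block` are unpadded, every block trims 4 pixels from each spatial side, which is why the decoder center-crops the encoder features and the final map is resized. The sketch below traces the feature-map side length through the network for an 80x80 input. It is plain arithmetic under those assumptions, and the function name `unet_spatial_sizes` is hypothetical.\n",
    "\n",
    "```python\n",
    "def unet_spatial_sizes(size=80, depth=3):\n",
    "    # Trace the square feature-map side length through the U-Net above,\n",
    "    # assuming unpadded 3x3 convs, 2x2 max pooling and stride-2 upconvs.\n",
    "    encoder_sizes = []\n",
    "    for _ in range(depth):\n",
    "        size -= 4                    # Block = two unpadded 3x3 convolutions\n",
    "        encoder_sizes.append(size)   # output kept for the skip connection\n",
    "        size //= 2                   # 2x2 max pooling\n",
    "    x = encoder_sizes[-1]            # deepest encoder output feeds the decoder\n",
    "    crops = []\n",
    "    for skip in reversed(encoder_sizes[:-1]):\n",
    "        x *= 2                       # ConvTranspose2d with kernel 2, stride 2\n",
    "        crops.append((skip, x))      # skip features are cropped from skip to x\n",
    "        x -= 4                       # decoder Block shrinks the map again\n",
    "    return encoder_sizes, crops, x\n",
    "\n",
    "enc, crops, out = unet_spatial_sizes()\n",
    "print(enc)    # [76, 34, 13]\n",
    "print(crops)  # [(34, 26), (76, 44)]\n",
    "print(out)    # 40\n",
    "```\n",
    "\n",
    "With an 80x80 input the head therefore produces a 40x40 map, which `F.interpolate` resizes back to 80x80 when `retainDim` is `True`."
   ]
  },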
  {
   "cell_type": "markdown",
   "id": "6aa536a1",
   "metadata": {},
   "source": [
    "Define a `model_creator` function for the Orca Estimator; it prints a summary of the model structure when called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "88dd965d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_creator(config):\n",
    "    model = UNet()\n",
    "    model.train()\n",
    "    print(model)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15fdcd7f",
   "metadata": {},
   "source": [
    "Define the loss function and the optimizer creator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c1de6bce",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = torch.nn.BCEWithLogitsLoss()\n",
    "\n",
    "def optimizer_creator(model, config):\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "    return optimizer"
   ]
  },
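  {
   "cell_type": "markdown",
   "id": "a2f68c14",
   "metadata": {},
   "source": [
    "The model head outputs raw logits rather than probabilities, which is why the loss is `BCEWithLogitsLoss`: it combines a sigmoid with binary cross-entropy. As a rough illustration, the pure-Python sketch below evaluates the per-element loss with a numerically stable, mathematically equivalent formula; the helper names are hypothetical, and the PyTorch loss additionally averages over all elements by default.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def bce_with_logits(logit, target):\n",
    "    # Stable binary cross-entropy on a raw logit x with label y:\n",
    "    # max(x, 0) - x * y + log(1 + exp(-|x|))\n",
    "    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))\n",
    "\n",
    "def sigmoid(logit):\n",
    "    return 1.0 / (1.0 + math.exp(-logit))\n",
    "\n",
    "# A logit of 0 means probability 0.5, so the loss is log(2) for either label.\n",
    "print(bce_with_logits(0.0, 1.0), math.log(2))\n",
    "# For y = 1 the stable form agrees with the textbook -log(sigmoid(x)).\n",
    "print(bce_with_logits(2.0, 1.0), -math.log(sigmoid(2.0)))\n",
    "```\n",
    "\n",
    "At inference time a probability mask can be obtained by applying a sigmoid to the model output."
   ]
  },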
  {
   "cell_type": "markdown",
   "id": "88826381",
   "metadata": {},
   "source": [
    "Create an Orca Estimator and train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "51dd7ebf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UNet(\n",
      "  (encoder): Encoder(\n",
      "    (encBlocks): ModuleList(\n",
      "      (0): Block(\n",
      "        (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))\n",
      "        (relu): ReLU()\n",
      "        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))\n",
      "      )\n",
      "      (1): Block(\n",
      "        (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "        (relu): ReLU()\n",
      "        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "      )\n",
      "      (2): Block(\n",
      "        (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
      "        (relu): ReLU()\n",
      "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
      "      )\n",
      "    )\n",
      "    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (decoder): Decoder(\n",
      "    (upconvs): ModuleList(\n",
      "      (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))\n",
      "      (1): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))\n",
      "    )\n",
      "    (dec_blocks): ModuleList(\n",
      "      (0): Block(\n",
      "        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "        (relu): ReLU()\n",
      "        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "      )\n",
      "      (1): Block(\n",
      "        (conv1): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))\n",
      "        (relu): ReLU()\n",
      "        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (head): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "orca_estimator = Estimator.from_torch(model=model_creator,\n",
    "                                      optimizer=optimizer_creator,\n",
    "                                      loss=criterion,\n",
    "                                      metrics=[criterion],\n",
    "                                      backend=\"spark\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52c640cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 50\n",
    "BATCH_SIZE = 32\n",
    "orca_estimator.fit(data=data_shards, epochs=EPOCHS, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9c18a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_orca_context()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37tf2",
   "language": "python",
   "name": "py37tf2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
